R

[R语言] 一种特殊的lasso求解

这是一篇关于提高计算速度的技术文档

Posted by Leung ZhengHua on 2017-11-16

本文总点击量

最近开始忙的要死,一大堆东西堆在一起但又不想做的感觉。今晚研究了Rcpp发现又有新的收获,Rcpp用来做一些简单的循环和判断还可以,涉及矩阵运算时难道还要逐行逐列相乘相加吗?想想都要放弃了,直到我遇上了RcppArmadillo包。它是C++里面专门做科学运算的,用法类似Matlab。有了这个包,我终于找到了一个更有效更容易提高计算速度的办法,在尊贵的阿苏斯上的测试对比发现,Rcpp:Matlab:R的速度比例是约为1:2:10。想要对比这三者的想法起源于统计计算的老师布置的作业,参考一份Matlab脚本先写了一份R的函数,测试之后发现速度贼慢,然后翻开Dirk Eddelbuettel的无缝整合一书,找到一套办法。

问题引入

在lasso问题中:

(1)min||(yxβ)||2+λ||β||s.t. βi0

在上面的约束下,可以化简目标函数为:

(2)f0=(yxβ)(yxβ)+λ1β=yyyxββxy+βxxβ+λ1β=yy2yxβ+βxxβ+λ1β=yy+βxxβ+(λ12yx)β

从上式可以看到,yy是一个给定的常数,有人写了一篇文章,证明了这个目标函数的优化可以通过下面的方法完成:

b=(λ12yx)a=[xx]+βc=[xx]β,其中[xx]+表示矩阵正部,矩阵中的负数设为0;[xx]表示矩阵负部,矩阵中的正数设为0,仅保留负数的绝对值。这个方法核心的关键就是解方程ax2+bxc=0,得到对称轴右边的解乘以原来的β就作为新的β估计值。具体步骤为:

  • β0=1,得到f0
  • β0=1代入a=[xx]+βc=[xx]β
  • 利用点乘计算新解β1=b+b2+4ac2aβ0
  • β1代入目标函数得到f1
  • 比较f1f0,若abs(f1f0)<eps则停止迭代;否则,令f0=f1,重复3、4步

那么我们现在来比较不同方法的代码吧:

Matlab

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
%% This code is used to solve the following constrained optimization problem
% \| Y - X* theta\|^2 + lambda sum_k theta_k, s.t. theta_k >=0
%% Input
% x,y -- data
% lam -- lambda
% ep -- error for critia to break the iteration
%% Code
function [output,k] = multi_update(x,y,lam_1,lam_2,ep)
[n,p] = size(x);
A = 2*(x'*x+lam_2*eye(p));
b = lam_1*ones(p,1)-2*x'*y;
theta0 = ones(p,1); % initial value for theta
f0 = 1/2*theta0'*A*theta0+b'*theta0+y'*y; % initial objective value
Ap = (abs(A)+A)/2; % positive part of A
An = (abs(A)-A)/2; % negative part of A
%% one step update
aa = Ap*theta0;
cc = An*theta0;
% aa(abs(aa)<ep) = ep; % why???
theta = (-b+sqrt(b.*b+4*aa.*cc))./(2*aa).*theta0; % update of the solution
theta(isnan(theta)) = 0; % when aa =0 ,theta should be 0, not NaN
f1 = 1/2*theta'*A*theta+b'*theta+y'*y; % new value of obj after the update of the parameter
%% more step undate
k=1;
while(f0>ep && abs((f1-f0))>ep) %收敛准则如果取百分比不超过ep会很快结束
f0 = f1;
theta0 = theta;
aa = Ap*theta0;
cc = An*theta0;
% aa(abs(aa)<ep) = ep; % why??
theta = (-b+sqrt(b.*b+4*aa.*cc))./(2*aa).*theta0;
theta(isnan(theta)) = 0; % when aa =0 ,theta should be 0, not NaN
f1 = 1/2*theta'*A*theta+b'*theta+y'*y;
k=k+1;
end
output = theta;
k=k;

R

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
Update<-function(x,y,lamb,bound){
n <- dim(x)[1];p <- dim(x)[2]
A = t(x)%*%x+diag(p)
v <- matrix(rep(1,p),p)
lam=NULL
lam<- matrix(rep(lamb,p),p)
f0 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
+(1/2)*t(y)%*%y + t(lam)%*%v)
aa = (abs(A)+A)/2
cc = (abs(A)-A)/2
a=aa %*% v
c=cc %*% v
b = lam - t(x)%*%y
v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
+(1/2)*t(y)%*%y + t(lam)%*%v)
k = 1
while(abs(f0-f1)>bound){
f0 <- f1
a=aa %*% v
c=cc %*% v
v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
+(1/2)*t(y)%*%y + t(lam)%*%v)
k=k+1
f0-f1
}
v[v <=1e-2] = 0
return(list(beta = v,k = k))
}

Rcpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
#include <Rcpp.h>
using namespace Rcpp;
using namespace arma;
// [[Rcpp::export]]
List timesTwo(mat x,mat y,double lamb,double bound) {
int p=x.n_cols;
mat A(p,p);
//NumericMatrix A=t(x)%*%x+diag(p);
A=x.t()*x+lamb * eye<mat>(p,p); //mat A(p,n)=0; wrong! set lamb to keep identical with matlab
// v <- matrix(rep(1,p),p)
mat v = ones<mat>(p,1);
// lam<- matrix(rep(lamb,p),p)
mat lam=lamb * ones<mat>(p,1);
// f0 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
// +(1/2)*t(y)%*%y + t(lam)%*%v)
// mat f0=((1/2)*v.t()*A*v-v.t()*x.t()*y+1/2*y.t()*y+lam.t()*v);
mat f0=((1/2)*v.t()*A*v-v.t()*x.t()*y+lam.t()*v);
// aa = (abs(A)+A)/2
mat aa=(abs(A)+A)/2;
// cc = (abs(A)-A)/2
mat cc=(abs(A)-A)/2;
// a=aa %*% v
mat a=aa*v;
// c=cc %*% v
mat c=cc*v;
// b = lam - t(x)%*%y
mat b=lam-x.t()*y;
// v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
// v.print("before");
v=(-b+sqrt(b%b+4*a%c))/(2*a)%v;
//v.print("after");
// f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
// +(1/2)*t(y)%*%y + t(lam)%*%v)
// mat f1 =((1/2)*v.t()*A*v-v.t()*x.t()*y+1/2*y.t()*y+lam.t()*v);
mat f1 =((1/2)*v.t()*A*v-v.t()*x.t()*y+lam.t()*v);
// k = 1
int k=1;
// while(abs(f0-f1)>bound){
// f0 <- f1
//
// a=aa %*% v
// c=cc %*% v
// v = (-b + sqrt(b^2 +4*a*c))/(2*a)*v
// f1 <- ((1/2)*t(v)%*%A%*%v- t(v)%*%t(x)%*%y
// +(1/2)*t(y)%*%y + t(lam)%*%v)
//
// k=k+1
// f0-f1
// }
//
while(fabs(f0(0,0)-f1(0,0))>bound) {
f0(0,0)=f1(0,0);
a=aa*v;
c=cc*v;
//(a/10000).print("a"); // must double "
//(b/10000).print("b");
//(c/10000).print("c");
v=(-b+sqrt(b%b+4*a%c))/(2*a)%v;
// f1=((1/2)*v.t()*A*v-v.t()*x.t()*y+1/2*y.t()*y+lam.t()*v);
f1=((1/2)*v.t()*A*v-v.t()*x.t()*y+lam.t()*v);
k=k+1;
//printf("%g\n",fabs(f0(0,0)-f1(0,0))); //abs() return int type
}
// v[v <=1e-2] = 0
for(int i=0;i<p;i++) {if(v[i]<0.001){v[i]=0;}}
//return(list(beta = v,k = k))
return List::create(_["beta"]=v,_["k"]=k);
}

数据来自于标普500和48个权重股的收益率数据,样本容量129(条观测),while循环迭代次数280+万次,耗时:R 700秒,Matlab 140秒,Rcpp 70秒。

生活就像海洋,只有意志坚定的人才能到达彼岸